#!/usr/bin/python
# -*- coding:utf-8 -*-
# @FileName : Test6(1).py
# Author    : myh

import torch

# 2x3 float matrix used to demonstrate axis-wise reductions.
X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Sum along each axis in turn. keepdim=True keeps the reduced axis with
# size 1, so dim=0 gives a (1, 3) result (column totals) and dim=1 gives
# a (2, 1) result (row totals). The ' \n ' separator matches the spacing
# produced by printing a bare '\n' between the two tensors.
print(*(X.sum(dim=axis, keepdim=True) for axis in (0, 1)), sep=' \n ')
